import sys
sys.path.append('./')
from utils.result_preprocess import OS_ATLAS_RES_PRE_PROCESS

class OS_ATLAS_Agent:
    """Facade that selects a backend agent for an OS-Atlas checkpoint.

    The backend is chosen from the ``policy_lm`` string: names containing
    ``"OS-Atlas-Pro-7B"`` use ``Qwen2_VL_Agent``; everything else uses
    ``InternVL_Agent``. Calls to ``get_action`` are forwarded to that backend.
    """

    def __init__(self, device, accelerator, cache_dir='~/.cache', dropout=0.5, policy_lm=None):
        """Build the wrapper and instantiate the matching backend agent.

        Args:
            device: Forwarded unchanged to the backend agent constructor.
            accelerator: Forwarded unchanged to the backend agent constructor.
            cache_dir: Accepted for interface compatibility; not used here.
            dropout: Accepted for interface compatibility; not used here.
            policy_lm: Model name/path used to pick the backend. Required in
                practice even though the signature defaults it to ``None``.

        Raises:
            TypeError: If ``policy_lm`` is ``None``.
        """
        # The substring checks below cannot work on None; fail with an
        # explicit message instead of the opaque "argument of type 'NoneType'
        # is not iterable" (same exception type as before).
        if policy_lm is None:
            raise TypeError(
                "policy_lm must be a model name or path, got None"
            )

        # Never reassigned in this class; _load_model stores the loaded model
        # on the backend agent instead.
        self.model = None
        self.policy_lm = policy_lm
        self.device = device
        self.accelerator = accelerator

        # Backends are imported lazily so only the selected one is loaded.
        if "OS-Atlas-Pro-7B" in self.policy_lm:
            from gui_speaker.models.QwenVL import Qwen2_VL_Agent
            self.agent = Qwen2_VL_Agent(device=device, accelerator=accelerator, policy_lm=self.policy_lm)
        else:
            from gui_speaker.models.InternVL import InternVL_Agent
            self.agent = InternVL_Agent(device=device, accelerator=accelerator, policy_lm=self.policy_lm)
        self.res_pre_process = self._res_pre_process()

    def _res_pre_process(self):
        """Return a new ``OS_ATLAS_RES_PRE_PROCESS`` instance."""
        return OS_ATLAS_RES_PRE_PROCESS()

    def _load_model(self):
        """Load the backend model and store it on ``self.agent.model``."""
        self.agent.model = self.agent._load_model()

    def get_action(self, obs, args):
        """Forward ``obs`` and ``args`` to the backend agent and return its result.

        When ``policy_lm`` does not contain ``"7B"``, ``obs['question']`` is
        first set (in place, on the caller's dict) to the content of the first
        entry in ``obs['messages']``.
        """
        # NOTE(review): this tests for "7B" while __init__ tests for
        # "OS-Atlas-Pro-7B", so a name containing "7B" but not
        # "OS-Atlas-Pro-7B" gets the InternVL backend without 'question'
        # being populated -- confirm that this is intended.
        if "7B" not in self.policy_lm:
            obs['question'] = obs['messages'][0]['content']
        return self.agent.get_action(obs, args)
            
        
    
    

